
```{r boilerplate}

library(augsynth)
library(tidyverse)
library(xtable)
library(scales)
library(ggrepel)

# Set the default graphics device for all subsequent chunks to PDF
knitr::opts_chunk$set(dev = "pdf")
```


```{r load_and_clean_data}
# Load the district-by-year panel used by every later chunk
nonfix_60_2006 <- read.csv(file = "augsyn_nonfix_60_2006_2020_07_30.csv")

# Fix issue with sped in Ypsilanti (leaid 81020): overwrite the 2007 value
# with the average of the 2006 and 2008 values for the same district
is_ypsi <- nonfix_60_2006$leaid == 81020
ypsi_neighbours <- c(
  nonfix_60_2006$sped[is_ypsi & nonfix_60_2006$year == 2006],
  nonfix_60_2006$sped[is_ypsi & nonfix_60_2006$year == 2008]
)
nonfix_60_2006$sped[is_ypsi & nonfix_60_2006$year == 2007] <- mean(ypsi_neighbours)


```


```{r flint_svd,  fig.width = 6.4, fig.height = 4.2}

### Verify the low rank structure
# Keep only the years before 2015 and the four outcomes
long.data.untreat <- nonfix_60_2006[nonfix_60_2006$year < 2015,
                                    c("leaid", "year", "math", "read", "sped", "attd")]

# Demean across years within each district (leaid), separately per outcome
long.data <- long.data.untreat %>%
  group_by(leaid) %>%
  mutate(across(all_of(c("math", "read", "sped", "attd")),
                ~ .x - mean(.x, na.rm = TRUE),
                .names = "tdm_{.col}")) %>%
  ungroup() %>%
  select(leaid, year, starts_with("tdm_")) %>%
  as.data.frame()  # reshape() below expects a plain data.frame

# One row per district, one column per outcome-year
short.data <- reshape(long.data, idvar = "leaid", timevar = "year", direction = "wide")
# Remove columns that are all NAs
short.data <- short.data[, colSums(is.na(short.data)) != nrow(short.data), drop = FALSE]
# Drop districts with any missing cell, then centre and scale every
# outcome-year column across districts. as.numeric() strips the one-column
# matrix (and its attributes) that scale() returns.
short.data <- na.omit(short.data) %>%
  mutate(across(starts_with("tdm_"), ~ as.numeric(scale(.x))))
short.data.values <- short.data %>%
  select(starts_with("tdm_"))

# Singular value decomposition and cumulative share of squared singular values
r.svd <- svd(short.data.values)
r.svd$cumulative.var <- cumsum(r.svd$d^2)

par(mfrow = c(1, 2))
plot(seq_along(r.svd$d), r.svd$d,
     xlab = "", ylab = "Ordered Singular Values",
     type = "b", pch = 19)
plot(r.svd$cumulative.var / sum(r.svd$d^2),
     xlab = "", ylab = "Cumulative prop. of variance explained",
     type = "b", pch = 19)

```


```{r fit_synths}


# Panels used for fitting: every outcome uses 2007 onwards, except the
# attendance-only fit, which uses 2009 onwards.
panel_from_2007 <- nonfix_60_2006 %>% filter(year >= 2007)
panel_from_2009 <- nonfix_60_2006 %>% filter(year >= 2009)

# Separate SCM fits, one per outcome. `fixedeff` is passed as a logical
# (previously the string "TRUE" in most calls) so that every fit in this
# file uses the same argument type.
syn_sep_math <- augsynth(math ~ treat, leaid, year, panel_from_2007,
                         progfunc = "None", fixedeff = TRUE, scm = TRUE)

syn_sep_read <- augsynth(read ~ treat, leaid, year, panel_from_2007,
                         progfunc = "None", fixedeff = TRUE, scm = TRUE)

syn_sep_sped <- augsynth(sped ~ treat, leaid, year, panel_from_2007,
                         progfunc = "None", fixedeff = TRUE, scm = TRUE)

syn_sep_attd <- augsynth(attd ~ treat, leaid, year, panel_from_2009,
                         progfunc = "None", fixedeff = TRUE, scm = TRUE)

# Multi-outcome fits: concatenated objective and averaged objective
syn_cat <- augsynth(math + read + sped + attd ~ treat, leaid, year, panel_from_2007,
                    progfunc = "None", scm = TRUE,
                    combine_method = "concat", fixedeff = TRUE)

syn_avg <- augsynth(math + read + sped + attd ~ treat, leaid, year, panel_from_2007,
                    progfunc = "None", fixedeff = TRUE, scm = TRUE,
                    combine_method = "avg")



# Balance matrix of the combined ("avg_concat") objective evaluated at nu = 0
Xs <- augsynth:::combine_outcomes(syn_cat$data_list, "avg_concat", TRUE, nu = 0)$wide_bal$X
trt <- syn_cat$data_list$trt

# NOTE(review): the split below assumes the first 8 columns of Xs are the
# averaged block and the remaining columns the concatenated block; confirm
# against augsynth:::combine_outcomes if the pre-period window changes.
n_avg_cols <- 8
avg_cols <- seq_len(n_avg_cols)
cat_cols <- (n_avg_cols + 1):ncol(Xs)

X1_avg <- Xs[trt == 1, avg_cols]
X0_avg <- Xs[trt == 0, avg_cols]
X1_cat <- Xs[trt == 1, cat_cols]
X0_cat <- Xs[trt == 0, cat_cols]

# compute heuristic nu: ratio of the root-mean-square imbalance in the
# averaged block to that in the concatenated block, both evaluated at the
# concatenated-objective weights
wts_cat <- syn_cat$weights

avg_obj <- mean((X1_avg - t(X0_avg) %*% wts_cat)^2)
cat_obj <- mean((X1_cat - t(X0_cat) %*% wts_cat)^2)

nu_heur <- sqrt(avg_obj) / sqrt(cat_obj)

# Combined-objective fit at the heuristic nu
syn_both <- augsynth(math + read + sped + attd ~ treat, leaid, year,
                     nonfix_60_2006 %>% filter(year >= 2007),
                     progfunc = "None", fixedeff = TRUE, scm = TRUE,
                     combine_method = "avg_concat", nu = nu_heur)



## Summary of results
summ_cat <- summary(syn_cat)
summ_avg <- summary(syn_avg)
summ_both <- summary(syn_both)

# Single-outcome summaries carry no Outcome column, so tag each one
summ_math <- summary(syn_sep_math)
summ_math$att$Outcome <- "math"

summ_read <- summary(syn_sep_read)
summ_read$att$Outcome <- "read"

summ_sped <- summary(syn_sep_sped)
summ_sped$att$Outcome <- "sped"

summ_attd <- summary(syn_sep_attd)
summ_attd$att$Outcome <- "attd"

# Stack the four separate fits into one summary-like object. The math summary
# is reused as the container instead of calling summary() a second time; only
# its `att` element is replaced.
summ <- summ_math
summ$att <- rbind(summ_math$att, summ_read$att, summ_sped$att, summ_attd$att)



# Tidy the effect estimates of a summary object for plotting.
#
# m: a list-like summary object whose `att` element is a data frame with
#    (at least) the columns Time, Outcome and Estimate.
#
# Returns those three columns, without the rows whose Outcome is "Average",
# with Outcome converted to a factor ordered math, read, sped, attd, plus an
# Outcome_label factor holding the display names. Outcomes missing from `m`
# (e.g. fits that exclude sped) are skipped, so no "Unknown levels" warnings
# are raised for them.
att_cleanup <- function(m) {
  outcome_labels <- c(math = "Math Achievement",
                      read = "Reading Achievement",
                      sped = "Special Needs",
                      attd = "Student Attendance")

  att <- m$att %>%
    select(all_of(c("Time", "Outcome", "Estimate"))) %>%
    filter(Outcome != "Average") %>%
    mutate(Outcome = as_factor(Outcome))

  # Only relevel/recode the outcomes that actually occur in this summary
  present <- intersect(names(outcome_labels), levels(att$Outcome))
  att$Outcome <- fct_relevel(att$Outcome, present)

  # fct_recode() expects new_label = "old_level" pairs
  recode_map <- setNames(present, unname(outcome_labels[present]))
  att$Outcome_label <- fct_recode(att$Outcome, !!!recode_map)

  att
}

# Tidy effect estimates for each weighting scheme
att_sep  <- summ %>% att_cleanup()
att_cat  <- summ_cat %>% att_cleanup()
att_avg  <- summ_avg %>% att_cleanup()
att_both <- summ_both %>% att_cleanup()
```



```{r figure_flint_summary_gap_annual_fixedeff_sped_orig_order, fig.width = 7.8, fig.height = 5.4}

# Minimum y-axis range to show in each facet (the axis may still extend
# further where the estimates exceed these bounds)
bounds <- data.frame(
  Outcome = c("Math", "Reading", "Special", "Attendance"),
  ymin = c(-.15, -.15, -.02, -.02),
  ymax = c(.15, .15, .02, .02),
  Outcome_label = c("Math Achievement", "Reading Achievement",
                    "Special Needs", "Student Attendance"),
  stringsAsFactors = FALSE
)

# Colour-blind-friendly palette (black and two other entries left out)
cbp2 <- c(
          # "#000000",
           "#009E73", "#E69F00", "#56B4E9",
          # "#F0E442", "#0072B2",
          "#D55E00", "#CC79A7")

# Tag each set of estimates with the weighting scheme that produced it
att_sep$category <- "Separate"
att_cat$category <- "Concatenate"
att_avg$category <- "Average"
att_both$category <- "Combined"

# Combine the dataframes. The sped estimates are left in their original sign;
# the former self-assignment of the sped rows was a no-op and has been removed.
combined_data <- rbind(att_sep, att_cat, att_avg, att_both)
combined_data$treated <- combined_data$Time >= 2015

# Attach the per-facet bounds to every row; used only by the invisible
# geom_blank() layers that stretch the y-axes
bounds_expand <- left_join(combined_data, bounds, by = "Outcome_label")

# Create the plot. Grouping by category x treated breaks each line at the
# intervention, so one thick line layer draws the pre- and post-period
# segments that were previously added as separate, overlapping layers.
ggplot(combined_data, aes(x = Time, y = Estimate,
                          group = interaction(category, treated),
                          color = category)) +
  geom_vline(xintercept = 2014.5, linetype = "dashed") +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_line(size = 1) +
  labs(x = "Year",
       y = "Difference Between \n Observed Flint & Synthetic Flint") +
  facet_wrap(vars(Outcome_label), scales = "free") +
  geom_blank(data = bounds_expand, aes(x = Time, y = ymin)) +
  geom_blank(data = bounds_expand, aes(x = Time, y = ymax)) +
  scale_x_continuous(breaks = pretty_breaks()) +
  theme_bw() +
  theme(panel.border = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.position = "bottom",
        legend.direction = "horizontal",
        legend.title = element_blank()) +
  scale_colour_manual(values = cbp2)
```


```{r fit_multiple_nu, cache = TRUE}


# Grid of nu values at which the combined objective is fitted
nus <- seq(from = 0, to = 1, length.out = 30)

# Fit the combined ("avg_concat") objective for one value of nu
fit_avg_concat <- function(nu) {
  augsynth(math + read + sped + attd ~ treat, leaid, year,
           nonfix_60_2006 %>% filter(year >= 2007),
           progfunc = "None", scm = TRUE, combine_method = "avg_concat",
           fixedeff = TRUE, nu = nu)
}

syn_boths <- lapply(nus, fit_avg_concat)

# Balance matrix and treatment indicator, taken from the first fit
first_fit <- syn_boths[[1]]
Xs <- augsynth:::combine_outcomes(first_fit$data_list, "avg_concat", TRUE, nu = 0)$wide_bal$X
trt <- first_fit$data_list$trt

# Columns 1-8 are used for the averaged objective, the rest for the
# concatenated objective
avg_cols <- 1:8
cat_cols <- 9:ncol(Xs)
X1_avg <- Xs[trt == 1, avg_cols]
X0_avg <- Xs[trt == 0, avg_cols]
X1_cat <- Xs[trt == 1, cat_cols]
X0_cat <- Xs[trt == 0, cat_cols]

# compute heuristic nu from the imbalance of the concatenated-objective weights
wts_cat <- syn_cat$weights
avg_obj <- mean((X1_avg - t(X0_avg) %*% wts_cat)^2)
cat_obj <- mean((X1_cat - t(X0_cat) %*% wts_cat)^2)
nu_heur <- sqrt(avg_obj) / sqrt(cat_obj)

# Append the heuristic nu and its fit as the last element of the grid
nus <- c(nus, nu_heur)
syn_boths[[length(syn_boths) + 1]] <- fit_avg_concat(nu_heur)

# One summary per fit, in the same order as `nus`
syn_summ_boths <- lapply(X = syn_boths, FUN = summary)


```


```{r frontier_plot, fig.width = 7.8, fig.height = 5.4}

# Trade-off between the averaged and concatenated imbalance objectives as nu
# varies. Each fit's weights are scored against both balance blocks.
map_df(seq_along(nus), function(i) {
  wts <- syn_boths[[i]]$weights
  data.frame(nu = nus[i],
             avg = mean((X1_avg - t(X0_avg) %*% wts)^2),
             cat = mean((X1_cat - t(X0_cat) %*% wts)^2))
}) %>%
  # Exact comparisons are intentional: 0 and 1 are the literal endpoints of
  # the grid and nu_heur was appended to `nus` unchanged.
  mutate(label = case_when(nu == 0 ~ "Concatenated Weights",
                           nu == 1 ~ "Average Weights",
                           nu == nu_heur ~ "Heuristic Choice",
                           TRUE ~ NA_character_)) %>%
  ggplot(aes(x = cat, y = avg)) +
  geom_line() +
  geom_point(data = . %>% filter(nu %in% c(0, 1, nu_heur))) +
  geom_label_repel(aes(label = label), nudge_x = 0.005, nudge_y = 0.0005) +
  xlab("Concatenated Objective") +
  ylab("Average Objective") +
  theme_bw()

```


```{r p_val_plot, fig.width = 7.8, fig.height = 5.4}

# p-values for each post-intervention year, one set of rows per value of nu
yearly_p_vals <- map_df(seq_along(nus), function(i) {
  syn_summ_boths[[i]]$att %>%
    distinct(Time, p_val) %>%
    na.omit() %>%
    mutate(nu = nus[i], Time = as.character(Time))
})

# p-value of the overall average effect, labelled "All"
overall_p_vals <- map_df(seq_along(nus), function(i) {
  data.frame(p_val = syn_summ_boths[[i]]$average_att$p_val[1],
             nu = nus[i],
             Time = "All")
})

bind_rows(yearly_p_vals, overall_p_vals) %>%
  ggplot(aes(x = nu, y = p_val, color = as.factor(Time))) +
  geom_line() +
  geom_hline(yintercept = 0.1, lty = 2) +   # reference line at p = 0.1
  geom_vline(xintercept = nu_heur, lty = 3) +  # heuristic choice of nu
  ylab("p value") +
  xlab(expression(nu)) +
  scale_color_brewer("Post-Intervention Year", type = "qual", palette = "Dark2") +
  theme_bw() +
  theme(legend.position = "bottom")
```



```{r removing_sped}


# Refit the multi-outcome models without the special-needs outcome
syn_cat_nosped <- augsynth(math + read + attd ~ treat, leaid, year,
                           nonfix_60_2006 %>% filter(year >= 2007),
                           progfunc = "None", scm = TRUE,
                           combine_method = "concat", fixedeff = TRUE)

syn_avg_nosped <- augsynth(math + read + attd ~ treat, leaid, year,
                           nonfix_60_2006 %>% filter(year >= 2007),
                           progfunc = "None", fixedeff = TRUE, scm = TRUE,
                           combine_method = "avg")

# Balance matrix for the combined objective at nu = 0. The treatment indicator
# is taken from this fit's own data rather than from an earlier fit, so the
# row subsetting cannot silently go out of sync with Xs_nosped.
Xs_nosped <- augsynth:::combine_outcomes(syn_cat_nosped$data_list, "avg_concat", TRUE, nu = 0)$wide_bal$X
trt_nosped <- syn_cat_nosped$data_list$trt

# Columns 1-8 are used for the averaged objective, the rest for the
# concatenated objective
X1_avg_nosped <- Xs_nosped[trt_nosped == 1, 1:8]
X0_avg_nosped <- Xs_nosped[trt_nosped == 0, 1:8]
X1_cat_nosped <- Xs_nosped[trt_nosped == 1, 9:ncol(Xs_nosped)]
X0_cat_nosped <- Xs_nosped[trt_nosped == 0, 9:ncol(Xs_nosped)]

# compute heuristic nu
wts_cat_nosped <- syn_cat_nosped$weights

avg_obj_nosped <- mean((X1_avg_nosped - t(X0_avg_nosped) %*% wts_cat_nosped)^2)
cat_obj_nosped <- mean((X1_cat_nosped - t(X0_cat_nosped) %*% wts_cat_nosped)^2)

nu_heur_nosped <- sqrt(avg_obj_nosped) / sqrt(cat_obj_nosped)

# Combined-objective fit at the heuristic nu
syn_both_nosped <- augsynth(math + read + attd ~ treat, leaid, year,
                            nonfix_60_2006 %>% filter(year >= 2007),
                            progfunc = "None", fixedeff = TRUE, scm = TRUE,
                            combine_method = "avg_concat", nu = nu_heur_nosped)

## Summary of results
summ_cat_nosped <- summary(syn_cat_nosped)
summ_avg_nosped <- summary(syn_avg_nosped)
summ_both_nosped <- summary(syn_both_nosped)

att_cat_nosped <- att_cleanup(summ_cat_nosped)
att_avg_nosped <- att_cleanup(summ_avg_nosped)
att_both_nosped <- att_cleanup(summ_both_nosped)

# Tag each set of estimates with the weighting scheme that produced it
att_cat_nosped$category <- "Concatenate"
att_avg_nosped$category <- "Average"
att_both_nosped$category <- "Combined"

```

```{r flint_nosped, fig.width = 7.5, fig.height = 3.5}


# Combine the dataframes: separate fits (minus special needs) plus the three
# multi-outcome fits that exclude special needs
combined_data <- rbind(att_sep %>% filter(Outcome_label != "Special Needs"),
                       att_cat_nosped, att_avg_nosped, att_both_nosped)
combined_data$treated <- combined_data$Time >= 2015
# No sped rows remain after the filter above, so the former sign flip of the
# sped estimates acted on zero rows and has been dropped.

# Attach the per-facet bounds; used only by the invisible geom_blank() layers
bounds_expand <- left_join(combined_data, bounds, by = "Outcome_label")

# Create the plot. Grouping by category x treated already breaks each line at
# the intervention, so a single line layer replaces the three overlapping ones.
ggplot(combined_data, aes(x = Time, y = Estimate,
                          group = interaction(category, treated),
                          color = category)) +
  geom_vline(xintercept = 2014.5, linetype = "dashed") +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_line(size = 1) +
  labs(x = "Year",
       y = "Difference Between \n Observed Flint & Synthetic Flint") +
  facet_wrap(~Outcome_label, scales = "free") +
  geom_blank(data = bounds_expand, aes(x = Time, y = ymin)) +
  geom_blank(data = bounds_expand, aes(x = Time, y = ymax)) +
  scale_x_continuous(breaks = pretty_breaks()) +
  theme_bw() +
  theme(panel.border = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.position = "bottom",
        legend.direction = "horizontal",
        legend.title = element_blank()) +
  scale_colour_manual(values = cbp2)

```



```{r all_scm_weights, fig.height = 4, fig.width = 6, message = FALSE, warning=FALSE, dpi=300}

# One column of donor weights per weighting scheme. data.frame() turns the
# spaces in the column names into dots (e.g. Combined.Objective).
all_weights <- data.frame(
  district = rownames(syn_sep_math$weights),
  `Math Achievement` = syn_sep_math$weights,
  `Reading Achievement` = syn_sep_read$weights,
  `Special Needs` = syn_sep_sped$weights,
  `Student Attendance` = syn_sep_attd$weights,
  `Concatenated Objective` = syn_cat$weights,
  `Averaged Objective` = syn_avg$weights,
  `Combined Objective` = syn_both$weights
)

# Order in which the weighting schemes are laid out in the heatmap
scheme_levels <- rev(c("Math Achievement", "Reading Achievement", "Student Attendance", "Special Needs", "Averaged Objective", "Concatenated Objective",
         "Combined Objective"))

# Long format: one row per (district, scheme); districts are ordered by their
# weight under the combined objective, negative weights are floored at zero
weights_long <- all_weights %>%
  mutate(district = fct_rev(fct_reorder(district, Combined.Objective))) %>%
  pivot_longer(cols = -district) %>%
  mutate(name = fct_relevel(str_replace(name, "\\.", " "), scheme_levels),
         value = pmax(value, 0))

ggplot(weights_long, aes(x = name, y = district, fill = value)) +
  geom_tile(color = "white", size = .5) +
  scale_fill_gradient("SCM Weight", low = "white", high = "black", labels = scales::percent) +
  guides(fill = guide_legend(title.position = "top", override.aes = list(color = "black"))) +
  scale_x_discrete(expand = c(0, 0)) +
  scale_y_discrete(expand = c(0, 0)) +
  xlab("") +
  ylab("Donor District") +
  coord_flip() +
  theme_bw() +
  theme(legend.position = "bottom",
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.y = element_blank())
```




```{r pre_fit_table, fig.width = 10, fig.height = 5}


# demean data first: subtract each district's 2007-2014 mean from every
# observation of the same outcome
demeaned_dat <- nonfix_60_2006 %>%
  filter(year >= 2007, year <= 2014) %>%
  group_by(leaid) %>%
  summarise(across(c(math, read, sped, attd),
                   ~ mean(.x, na.rm = TRUE),
                   .names = "{.col}_pre_mean")) %>%
  # bring the per-district means back onto the full panel (all years)
  right_join(nonfix_60_2006, by = "leaid") %>%
  mutate(math_demeaned = math - math_pre_mean,
         read_demeaned = read - read_pre_mean,
         sped_demeaned = sped - sped_pre_mean,
         attd_demeaned = attd - attd_pre_mean)

# compute_rmspe <- function(X, trt, feff, wts) {
#   t0 <- ncol(X)
#   X1 <- (X - feff[,1:t0])[trt == 1,]
#   X0 <- (X - feff[,1:t0])[trt == 0,]
#   sqrt(mean((X1 - t(X0) %*% wts)^2))
# }

# Pre-period fit of a set of donor weights, relative to uniform weights.
#
# X:   matrix with one row per unit and one column per pre-period.
# trt: vector with one entry per row of X; 1 marks treated units, 0 donors.
# wts: donor weights (vector or one-column matrix), one per donor row of X.
#
# Returns mean((treated - weighted donors)^2) divided by
# mean((treated - unweighted donor mean)^2). Despite the function name this
# is a ratio of mean squared errors, not a root mean squared error.
# If several units are treated, their average trajectory is used.
compute_rmspe <- function(X, trt, wts) {
  X <- as.matrix(X)
  # drop = FALSE keeps matrices even with a single pre-period column, where
  # colMeans() would otherwise fail on a dropped vector
  donors <- X[trt == 0, , drop = FALSE]
  treated <- colMeans(X[trt == 1, , drop = FALSE])

  synthetic <- drop(t(donors) %*% wts)
  uniform <- colMeans(donors)

  mean((treated - synthetic)^2) / mean((treated - uniform)^2)
}


# Panels used below: 2007 onwards for everything except the attendance-only
# fit, which uses 2009 onwards. Outcomes are already demeaned, so every fit
# uses fixedeff = FALSE.
demeaned_from_2007 <- demeaned_dat %>% filter(year >= 2007)
demeaned_from_2009 <- demeaned_dat %>% filter(year >= 2009)

# --- Separate fits, one per outcome ---------------------------------------
syn_sep_math_d <- augsynth(math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                           progfunc = "None", fixedeff = FALSE, scm = TRUE)

syn_sep_read_d <- augsynth(read_demeaned ~ treat, leaid, year, demeaned_from_2007,
                           progfunc = "None", fixedeff = FALSE, scm = TRUE)

syn_sep_sped_d <- augsynth(sped_demeaned ~ treat, leaid, year, demeaned_from_2007,
                           progfunc = "None", fixedeff = FALSE, scm = TRUE)

syn_sep_attd_d <- augsynth(attd_demeaned ~ treat, leaid, year, demeaned_from_2009,
                           progfunc = "None", fixedeff = FALSE, scm = TRUE)

# --- Concatenated objective, holding out one focal outcome -----------------
syn_cat_nomath_d <- augsynth(read_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)
syn_cat_noread_d <- augsynth(math_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)
syn_cat_noattd_d <- augsynth(read_demeaned + sped_demeaned + math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)
# sped is the held-out outcome here, so it must not appear in the formula
# (it previously did, which made this fit use all four outcomes)
syn_cat_nosped_d <- augsynth(read_demeaned + math_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)

# Concatenated objective, holding out the focal outcome and sped
syn_cat_nomath_nosped_d <- augsynth(read_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                    progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)
syn_cat_noread_nosped_d <- augsynth(math_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                    progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)
syn_cat_noattd_nosped_d <- augsynth(read_demeaned + math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                    progfunc = "None", scm = TRUE, combine_method = "concat", fixedeff = FALSE)

# --- Averaged objective, holding out one focal outcome ---------------------
syn_avg_nomath_d <- augsynth(read_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")
syn_avg_noread_d <- augsynth(math_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")
syn_avg_noattd_d <- augsynth(read_demeaned + sped_demeaned + math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")
# sped held out (see note on syn_cat_nosped_d)
syn_avg_nosped_d <- augsynth(read_demeaned + math_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                             progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")

# Averaged objective, holding out the focal outcome and sped
syn_avg_nomath_nosped_d <- augsynth(read_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                    progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")
syn_avg_noread_nosped_d <- augsynth(math_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                    progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")
syn_avg_noattd_nosped_d <- augsynth(read_demeaned + math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                    progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg")

# --- Combined objective (nu = 0.5), holding out one focal outcome ----------
syn_both_nomath_d <- augsynth(read_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                              progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)
syn_both_noread_d <- augsynth(math_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                              progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)
syn_both_noattd_d <- augsynth(read_demeaned + sped_demeaned + math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                              progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)
# sped held out (see note on syn_cat_nosped_d); otherwise this fit would use
# the same four outcomes as syn_both_d
syn_both_nosped_d <- augsynth(read_demeaned + math_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                              progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)

# Combined objective, holding out the focal outcome and sped
syn_both_nomath_nosped_d <- augsynth(read_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                     progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)
syn_both_noread_nosped_d <- augsynth(math_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                     progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)
syn_both_noattd_nosped_d <- augsynth(read_demeaned + math_demeaned ~ treat, leaid, year, demeaned_from_2007,
                                     progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)

# Combined objective on all four outcomes
syn_both_d <- augsynth(math_demeaned + read_demeaned + sped_demeaned + attd_demeaned ~ treat, leaid, year, demeaned_from_2007,
                       progfunc = "None", fixedeff = FALSE, scm = TRUE, combine_method = "avg_concat", nu = 0.5)

# Entry for a single-outcome fit: its data$X, data$trt, mhat and weights
single_outcome_entry <- function(fit) {
  list(X = fit$data$X,
       trt = fit$data$trt,
       feff = fit$mhat,
       wts = fit$weights)
}

# Entry for a multi-outcome fit: only the donor weights are kept
weights_entry <- function(fit) {
  list(wts = fit$weights)
}

# Lookup table keyed by outcome name (separate fits) or weighting-scheme name
# (multi-outcome fits)
outcome_data_list <- c(
  lapply(list(math = syn_sep_math_d,
              read = syn_sep_read_d,
              sped = syn_sep_sped_d,
              attd = syn_sep_attd_d),
         single_outcome_entry),
  lapply(list(cat_nomath = syn_cat_nomath_d,
              cat_noread = syn_cat_noread_d,
              cat_noattd = syn_cat_noattd_d,
              cat_nosped = syn_cat_nosped_d,
              nosped_cat_nomath = syn_cat_nomath_nosped_d,
              nosped_cat_noread = syn_cat_noread_nosped_d,
              nosped_cat_noattd = syn_cat_noattd_nosped_d,
              avg_nomath = syn_avg_nomath_d,
              avg_noread = syn_avg_noread_d,
              avg_noattd = syn_avg_noattd_d,
              avg_nosped = syn_avg_nosped_d,
              nosped_avg_nomath = syn_avg_nomath_nosped_d,
              nosped_avg_noread = syn_avg_noread_nosped_d,
              nosped_avg_noattd = syn_avg_noattd_nosped_d,
              both_nomath = syn_both_nomath_d,
              both_noread = syn_both_noread_d,
              both_noattd = syn_both_noattd_d,
              both_nosped = syn_both_nosped_d,
              nosped_both_nomath = syn_both_nomath_nosped_d,
              nosped_both_noread = syn_both_noread_nosped_d,
              nosped_both_noattd = syn_both_noattd_nosped_d,
              both = syn_both_d),
         weights_entry)
)
outcome_list <- c("math", "read", "sped", "attd")
# Weighting schemes shown in the figure (commented-out entries are fitted
# above but not plotted)
weight_list <- c(outcome_list,
                #  "cat_nomath", "cat_noread", "cat_noattd", "cat_nosped",
                #  "nosped_cat_nomath", "nosped_cat_noread", "nosped_cat_noattd",
                #  "avg_nomath", "avg_noread", "avg_noattd", "avg_nosped",
                #  "nosped_avg_nomath", "nosped_avg_noread", "nosped_avg_noattd",
                 "both_nomath", "both_noread", "both_noattd", "both_nosped",
                 "nosped_both_nomath", "nosped_both_noread", "nosped_both_noattd",
                 "both")

# Display order of the weighting schemes on the y-axis
type_order <- rev(c("Math Achievement", "Reading Achievement", "Student Attendance", "Special Needs",
                    "Concatenated", "Concatenated (w/o Spec. Needs)",
                    "Averaged", "Averaged (w/o Spec. Needs)",
                    "Combined", "Combined (w/o Focal Outcome)",
                    "Combined (w/o Focal Outcome & Spec. Needs)"))

expand.grid(weight = weight_list, outcome = outcome_list) %>%
  # Score every weighting scheme on every outcome's pre-period data.
  # map2_dbl() iterates over the two columns directly instead of apply(),
  # which would first coerce the data frame to a character matrix.
  mutate(rmspe = map2_dbl(as.character(weight), as.character(outcome),
                          function(w, o) {
                            compute_rmspe(outcome_data_list[[o]]$X,
                                          outcome_data_list[[o]]$trt,
                                          outcome_data_list[[w]]$wts)
                          })) %>%
  # All right-hand sides are character so case_when() needs no type coercion.
  # The more specific "nosped_*" patterns must be tested before the generic ones.
  mutate(type = case_when(
    weight %in% c("math", "read", "sped", "attd") ~ as.character(weight),
    str_detect(weight, "nosped_cat_") ~ "Concatenated (w/o Spec. Needs)",
    str_detect(weight, "nosped_avg_") ~ "Averaged (w/o Spec. Needs)",
    str_detect(weight, "nosped_both_") ~ "Combined (w/o Focal Outcome & Spec. Needs)",
    str_detect(weight, "cat_") ~ "Concatenated",
    str_detect(weight, "avg_") ~ "Averaged",
    str_detect(weight, "both_") ~ "Combined (w/o Focal Outcome)",
    str_detect(weight, "both") ~ "Combined"
    ),
    # Outcome that was left out when fitting the scheme (NA if none)
    dropped_outcome = case_when(
    weight %in% c("math", "read", "sped", "attd") ~ NA_character_,
    str_detect(weight, "nosped_cat_") ~ str_replace(weight, "nosped_cat_no", ""),
    str_detect(weight, "nosped_avg_") ~ str_replace(weight, "nosped_avg_no", ""),
    str_detect(weight, "nosped_both_") ~ str_replace(weight, "nosped_both_no", ""),
    str_detect(weight, "cat_") ~ str_replace(weight, "cat_no", ""),
    str_detect(weight, "avg_") ~ str_replace(weight, "avg_no", ""),
    str_detect(weight, "both_") ~ str_replace(weight, "both_no", ""),
    TRUE ~ NA_character_
    )) %>%
  # Keep schemes with nothing held out, plus hold-out schemes evaluated on
  # the outcome they held out
  filter(is.na(dropped_outcome) | dropped_outcome == outcome) %>%
  mutate(type = fct_recode(type,
                           `Math Achievement` = 'math',
                           `Reading Achievement` = 'read',
                           `Special Needs` = 'sped',
                           `Student Attendance` = 'attd'),
         # Relevel only the schemes that are present, which avoids
         # "Unknown levels" warnings for the ones commented out above
         type = fct_relevel(type, intersect(type_order, levels(type))),
         outcome = fct_recode(outcome,
                              `Math Achievement` = 'math',
                              `Reading Achievement` = 'read',
                              `Special Needs` = 'sped',
                              `Student Attendance` = 'attd'),
         outcome = fct_relevel(outcome, c("Math Achievement", "Reading Achievement", "Student Attendance", "Special Needs")),
        #  separate = ifelse(is.na(dropped_outcome), "Separate Weights", "Combined Weights")
         separate = ifelse(str_detect(type, "Combined"), "Combined Weights", "Separate Weights")
         ) %>%
  ggplot(aes(x = rmspe, y = type)) +
  geom_point() +
  geom_vline(xintercept = 0, lty = 2) +
  facet_grid(separate ~ outcome, scales = "free") +
  scale_x_continuous("MSPE relative to uniform weights (holding out focal outcome)", labels = scales::percent) +
  ylab("Weight Objective") +
  theme_bw()


```


```{r syn_both_weights}

# Donor districts that receive more than `min_weight` in a fitted synthetic
# control, with district names attached, sorted from largest to smallest weight.
#
# fit:        a fitted augsynth object; its `weights` rownames are parsed as
#             leaid values.
# min_weight: donors with weight at or below this threshold are dropped.
donor_weights <- function(fit, min_weight = 0.001) {
  district_names <- nonfix_60_2006 %>% distinct(leaid, name)

  data.frame(leaid = as.integer(rownames(fit$weights)),
             weight = fit$weights) %>%
    # explicit key: same join as before, without the "Joining with" message
    left_join(district_names, by = "leaid") %>%
    arrange(desc(weight)) %>%
    filter(weight > min_weight)
}

# Largest donors for the separate math fit and the concatenated fit
donor_weights(syn_sep_math)

donor_weights(syn_cat)

# LaTeX table of the donors for the combined fit
donor_weights(syn_both) %>%
  select(name, weight) %>%
  rename(`District Name` = name, `Combined SCM Weight` = weight) %>%
  xtable()
```